using LinearAlgebra, Distributions, StatsBase, Plots,MLDatasets
using Statistics
using Nemo, Images
using ToeplitzMatrices
using Random
using Kronecker
using StatPlots
using SpecialFunctions: erfc
using NLsolve
using MAT
push!(LOAD_PATH,"data/")
include("utility.jl")
#using PGFPlotsX
# ---- Experiment configuration ------------------------------------------------
# data_type selects which branch below builds the data: "synthetic", "MNIST",
# "OFFICE" (SURF features) or "OFFICE_VGG" (VGG-FC7 features).
data_type = "OFFICE_VGG"
# Feature dimension; each recognised data_type branch below assigns its own value.
p = 800
# Hyper-parameters forwarded to RMTMTLLSSVM.
γ = [0.1; 1]
λ = 1
# Source-target pair for the Office data types, written "<source>-<target>" over
# the domains Ca (Caltech-10), A (amazon), W (webcam) and D (dslr).
dataset = "Ca-W"
"""
    read_mat_vars(path, names...)

Open the MAT file at `path`, read the variables `names` and return them as a
tuple in the same order. The file is closed even if a read throws.
"""
function read_mat_vars(path, names...)
    file = matopen(path)
    try
        return map(name -> read(file, name), names)
    finally
        close(file)
    end
end

"""
    group_columns_by_class(A, labels, m)

Return the columns of `A` regrouped by label: every column labelled 1, then every
column labelled 2, …, up to `m`. Columns keep their relative order within a class;
columns whose label is outside `1:m` are dropped.
"""
group_columns_by_class(A, labels, m) = reduce(hcat, A[:, vec(labels .== c)] for c in 1:m)

"""
    office_pair_files(data_type, dataset)

Return `(src, tgt, transpose_tgt_labels)` for an Office/Caltech-10 domain pair.

`data_type` is "OFFICE" (SURF files) or "OFFICE_VGG" (VGG-FC7 files); `dataset`
is "<source>-<target>" with domains "Ca" (Caltech-10), "A" (amazon), "W" (webcam)
and "D" (dslr). `src`/`tgt` are `(path, features key, labels key)` tuples and
`transpose_tgt_labels` tells whether the target labels must be transposed after
reading. Throws `ArgumentError` for an unknown data type or pair.
"""
function office_pair_files(data_type, dataset)
    if data_type == "OFFICE"
        files = Dict(
            "Ca" => ("data/Caltech10_zscore_SURF_L10.mat", "Xt", "Yt"),
            "A" => ("data/amazon_zscore_SURF_L10.mat", "Xt", "Yt"),
            "W" => ("data/webcam_zscore_SURF_L10.mat", "Xt", "Yt"),
            "D" => ("data/dslr_zscore_SURF_L10.mat", "Xs", "Ys"),
        )
        if dataset in ("Ca-W", "W-Ca")
            # The Caltech-10 <-> webcam pairs use the raw (non-z-scored) SURF files.
            files["Ca"] = ("data/Caltech10_SURF_L10.mat", "fts", "labels")
            files["W"] = ("data/webcam_SURF_L10.mat", "fts", "labels")
        end
        transpose_tgt_labels = false
    elseif data_type == "OFFICE_VGG"
        files = Dict(
            "Ca" => ("data/Caltech_VGG-FC7.mat", "FTS", "LABELS"),
            "A" => ("data/amazon_VGG-FC7.mat", "FTS", "LABELS"),
            "W" => ("data/webcam_VGG-FC7.mat", "FTS", "LABELS"),
            # NOTE(review): read with the SURF-style keys "Xs"/"Ys", unlike the other
            # VGG files — confirm against the MAT file.
            "D" => ("data/dslr_VGG-FC7.mat", "Xs", "Ys"),
        )
        transpose_tgt_labels = dataset == "Ca-W"
        if dataset == "Ca-W"
            # NOTE(review): only this pair reads a lower-case "caltech_" file and
            # transposes the target labels; confirm which spelling/shape is correct.
            files["Ca"] = ("data/caltech_VGG-FC7.mat", "FTS", "LABELS")
        end
    else
        throw(ArgumentError("not an Office data_type: \"$data_type\""))
    end
    domains = String.(split(dataset, '-'))
    valid = length(domains) == 2 && domains[1] != domains[2] &&
            all(d -> haskey(files, d), domains)
    valid || throw(ArgumentError(
        "unknown dataset \"$dataset\": expected \"<source>-<target>\" with two " *
        "different domains among Ca, A, W, D"))
    return files[domains[1]], files[domains[2]], transpose_tgt_labels
end

"""
    load_office_pair(src, tgt, m; transpose_tgt_labels=false)

Load an Office/Caltech-10 source/target pair whose classes are labelled `1:m`.

`src`/`tgt` are `(path, features key, labels key)` tuples as returned by
`office_pair_files`. Source features are transposed after reading; target
features are passed to `partitionTrain(…, 0.5)` as read, which splits the target
into a training part and a test part. The whole source domain doubles as the
source test set.

Returns `(Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1, ns, n_test)`
with every feature matrix grouped by class (`group_columns_by_class`), every label
array sorted, `ns` the 2m×1 per-class training counts (source, then target) and
`n_test` the per-class test counts (source, then target).
"""
function load_office_pair(src, tgt, m; transpose_tgt_labels=false)
    src_fts, Slabels = read_mat_vars(src...)
    Tftstot, Tlabelstot = read_mat_vars(tgt...)
    if transpose_tgt_labels
        Tlabelstot = Tlabelstot'
    end
    Sfts = transpose(src_fts)
    Tfts, Tlabels, X_test1, y_test1 = partitionTrain(Tftstot, Tlabelstot, 0.5)
    # Capture the source test set *before* regrouping so features and labels stay paired.
    X_test, y_test = Sfts, Slabels
    Sfts = group_columns_by_class(Sfts, Slabels, m)
    Tfts = group_columns_by_class(Tfts, Tlabels, m)
    X_test = group_columns_by_class(X_test, y_test, m)
    X_test1 = group_columns_by_class(X_test1, y_test1, m)
    # Counts are indexed by class id so they match the grouped column order.
    class_counts(labels) = [sum(labels .== c) for c in 1:m]
    ns = reshape([class_counts(Slabels); class_counts(Tlabels)], :, 1)
    n_test = [class_counts(y_test); class_counts(y_test1)]
    return (Sfts, sort(Slabels, dims=1), Tfts, sort(Tlabels, dims=1),
            X_test, sort(y_test, dims=1), X_test1, sort(y_test1, dims=1), ns, n_test)
end

# Build the training/test data for the selected data_type.
if data_type == "synthetic"
    p = 100
    m = 10; k = 2; β = 0.2
    cons = rand(m*k, 1) + ones(m*k, 1)
    ns = convert.(Int, floor.(cons * p))
    n_test = convert.(Int, 100 * ones(m*k, 1))
    M, Σ = generate_statistic(k, m, p, β)
    Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1 =
        generate_data(p, ns, k, m, β, M, Σ, data_type, n_test)
    y_test1 = vec(reduce(hcat, [i * ones(n_test[i], 1) for i = 1:m]))
elseif data_type == "MNIST"
    p = 784
    k = 2; m = 2; β = 0.5
    ns = convert.(Int, [200, 200, 100, 100])
    label11 = 1; label12 = 4; label21 = 7; label22 = 9
    train_x, train_y = MNIST.traindata()
    train_x_reshape = reshape(train_x, 784, 60000)
    F11 = train_x_reshape[:, train_y .== label11]
    F12 = train_x_reshape[:, train_y .== label12]
    F21 = train_x_reshape[:, train_y .== label21]
    F22 = train_x_reshape[:, train_y .== label22]
    Sfts = convert.(Float64, [F11[:, 1:ns[1]] F12[:, 1:ns[2]]])
    Slabels = [ones(ns[1], 1); -ones(ns[2], 1)]
    Tfts = convert.(Float64, [F21[:, 1:ns[3]] F22[:, 1:ns[4]]])
    Tlabels = [ones(ns[3], 1); -ones(ns[4], 1)]
    test_x, test_y = MNIST.testdata()
    test_x_reshape = reshape(test_x, 784, 10000)
    F11_test = test_x_reshape[:, test_y .== label11]
    F12_test = test_x_reshape[:, test_y .== label12]
    F21_test = test_x_reshape[:, test_y .== label21]
    F22_test = test_x_reshape[:, test_y .== label22]
    X_test = convert.(Float64, [F11_test F12_test])
    X_test1 = convert.(Float64, [F21_test F22_test])
    y_test = [ones(size(F11_test, 2), 1); -ones(size(F12_test, 2), 1)]
    y_test1 = [ones(size(F21_test, 2), 1); -ones(size(F22_test, 2), 1)]
    # Per-(task, class) test sizes; computed only once the test splits above exist.
    n_test = reshape([size(F, 2) for F in (F11_test, F12_test, F21_test, F22_test)], :, 1)
elseif data_type == "OFFICE" || data_type == "OFFICE_VGG"
    m = 10; k = 2
    # Feature dimension: 800 for the SURF files, 4096 for the VGG-FC7 files.
    p = data_type == "OFFICE" ? 800 : 4096
    src_file, tgt_file, transpose_tgt_labels = office_pair_files(data_type, dataset)
    Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1, ns, n_test =
        load_office_pair(src_file, tgt_file, m; transpose_tgt_labels=transpose_tgt_labels)
else
    throw(ArgumentError("unknown data_type \"$data_type\""))
end
# Multi-class extension (one-versus-all): the class loop calls RMTMTLLSSVM once per
# class i = 1:m and stores its outputs in column i of the buffers below.

# Allocated but not written by any code visible in this file.
score_th, score_th_opt = zeros(m * k, m), zeros(m * k, m)
variance_th, variance_th_opt = zeros(m, m * k), zeros(m, m * k)
y_opt, y_opt_opt = zeros(m * k, m), zeros(m * k, m)
obj3, obj4 = zeros(m * k, m), zeros(m * k, m)

# Bookkeeping for reordering the m*k (task, class) blocks.
init_order = collect(1:k*m)     # identity permutation; the class loop permutes it in place
n = sum(ns)                     # sum of all per-(task, class) counts in ns
nsa = zeros(1, 2 * k)
# Per-task class counts, each task's column prefixed by a 0, flattened to one row:
# [0, ns[1], …, ns[m], 0, ns[m+1], …, ns[2m], …].
nsi = reshape(vcat(zeros(1, k), reshape(ns, m, k)), 1, k * m + k)
nsi_test = reshape(vcat(zeros(1, k), reshape(n_test, m, k)), 1, k * m + k)

# Score buffers: one row per column of X_test / X_test1, one column per class.
gx_s, gx_s_opt = zeros(size(X_test, 2), m), zeros(size(X_test, 2), m)
gx_t, gx_t_opt = zeros(size(X_test1, 2), m), zeros(size(X_test1, 2), m)

# Theoretical / empirical statistics, one column per class.
μ_th, μ_th_opt = zeros(2 * k, m), zeros(2 * k, m)
μ_emp, μ_emp_opt = zeros(2 * k, m), zeros(2 * k, m)
σ_th, σ_th_opt = zeros(k, m), zeros(k, m)
σ_emp, σ_emp_opt = zeros(2 * k, m), zeros(2 * k, m)

# Error rates, one entry per class.
error_source_emp, error_source_emp_opt = zeros(1, m), zeros(1, m)
error_target_emp, error_target_emp_opt = zeros(1, m), zeros(1, m)
error_source_th, error_source_th_opt = zeros(1, m), zeros(1, m)
error_target_th, error_target_th_opt = zeros(1, m), zeros(1, m)

# All training samples side by side: source columns first, then target columns.
X = hcat(Sfts, Tfts)
"""
    class_first_order(counts, c)

Column permutation for one task whose columns are grouped by class, with
`counts[j]` columns for class `j`: the columns of class `c` first, then every
other column in its original order.
"""
function class_first_order(counts, c)
    stop = sum(counts[1:c])
    lead = collect(stop-counts[c]+1:stop)
    return [lead; setdiff(1:sum(counts), lead)]
end

# Loop-invariant slices: source (task 1) and target (task k) training columns of X,
# with their per-class counts.
src_counts = ns[1:m]
tgt_counts = ns[m*(k-1)+1:k*m]
Xt1 = X[:, 1:sum(src_counts)]
Xt2 = X[:, sum(ns[1:m*(k-1)])+1:end]

# One-versus-all: for each class i, call RMTMTLLSSVM with class i as the positive
# class and store its outputs in column i of the result buffers.
# Loop-local names (src_labels, tgt_labels, …) deliberately differ from the
# globals so the top-level loop has no ambiguous soft-scope assignments.
for i = 1:m
    # Move class i to the front of the class order of task 1 and of task k; applied
    # cumulatively from i = 1 this keeps the other classes in their original order.
    if i != 1
        init_order[1] = i
        init_order[i] = i - 1
        init_order[m*(k-1)+1] = m*(k-1) + i
        init_order[m*(k-1)+i] = m*(k-1) + i - 1
    end
    nsn = ns[init_order]              # per-class training counts, class i first
    nsn_test = n_test[init_order]     # per-class test counts, class i first

    # Permute the training columns the same way: class i first within each task.
    X1p = Xt1[:, class_first_order(src_counts, i)]
    X2 = Xt2[:, class_first_order(tgt_counts, i)]

    # Labels: +1 for the leading (class i) block of every task, -1 elsewhere.
    # Block sizes must come from the permuted counts `nsn`, because the columns of
    # X1p/X2 were permuted with the same order; the unpermuted `ns` would mislabel
    # samples whenever ns[i] != ns[1].
    tildey = -ones(m * k)
    tildey[1:m:end] .= 1
    yc = reduce(vcat, [fill(tildey[h], nsn[h]) for h in 1:m*k])
    n1 = sum(nsn[1:m])
    src_labels = yc[1:n1]
    tgt_labels = yc[n1+1:end]

    # Aggregated sizes [class i, all other classes] for each task (k = 2 layout).
    ne = [nsn[1]; sum(nsn[2:m]); nsn[m+1]; sum(nsn[m+2:end])]
    nst = [nsn_test[1]; sum(nsn_test[2:m]); nsn_test[m+1]; sum(nsn_test[m+2:end])]

    # NOTE(review): `nst` lists class i first, whereas X_test/X_test1 are passed in
    # their class-sorted order — confirm RMTMTLLSSVM does not slice them by `nst`.
    (gx_s1, gx_t1, μ_th1, μ_emp1, σ_th1, σ_emp1,
     error_source_emp1, error_target_emp1, error_source_th1, error_target_th1,
     gx_s_opt1, gx_t_opt1, μ_th_opt1, μ_emp_opt1, σ_th_opt1, σ_emp_opt1,
     error_source_emp_opt1, error_target_emp_opt1, error_source_th_opt1,
     error_target_th_opt1) = RMTMTLLSSVM(X1p, src_labels, X2, tgt_labels, λ, γ,
                                         X_test, X_test1, ne, nst, i)

    gx_s[:, i] = gx_s1; gx_s_opt[:, i] = gx_s_opt1
    # NOTE(review): only the optimised target scores are rescaled, and by the
    # non-optimised σ_th1[2]; confirm this is intended.
    gx_t[:, i] = gx_t1; gx_t_opt[:, i] = gx_t_opt1 ./ sqrt(abs(σ_th1[2]))
    μ_th[:, i] = μ_th1; μ_th_opt[:, i] = μ_th_opt1
    μ_emp[:, i] = μ_emp1; μ_emp_opt[:, i] = μ_emp_opt1
    σ_th[:, i] = σ_th1; σ_th_opt[:, i] = σ_th_opt1
    σ_emp[:, i] = σ_emp1; σ_emp_opt[:, i] = σ_emp_opt1
    error_source_emp[i] = error_source_emp1; error_source_emp_opt[i] = error_source_emp_opt1
    error_target_emp[i] = error_target_emp1; error_target_emp_opt[i] = error_target_emp_opt1
    error_source_th[i] = error_source_th1; error_source_th_opt[i] = error_source_th_opt1
    error_target_th[i] = error_target_th1; error_target_th_opt[i] = error_target_th_opt1
end
# Evaluation: each target test sample gets the class whose one-vs-all score (column
# of gx_t / gx_t_opt) is largest, and the prediction is compared with y_test1.
# vec(y_test1) keeps the comparison elementwise whatever y_test1's shape (vector,
# n×1 or 1×n matrix); a length-n vector against a 1×n row would broadcast to n×n.
pred = convert.(Float64, argmax.(eachrow(real(gx_t))))
error_1 = sum(pred .!= vec(y_test1)) / length(y_test1)
pred_opt = convert.(Float64, argmax.(eachrow(real(gx_t_opt))))
error_opt = sum(pred_opt .!= vec(y_test1)) / length(y_test1)
acc = 1 - error_1
acc_opt = 1 - error_opt
println(acc)
println(acc_opt)
